import os
import cv2
import lmdb
import torch
from jpeg2dct.numpy import load, loads
import numpy as np
import torch.nn as nn
import gc
import math
import time
import copy
import logging
import torch.optim as optim
import torch.distributed as dist
import pickle
import six
from glob import glob
from PIL import Image
from tqdm import tqdm
from torch.autograd import Variable
from torch.cuda.amp import autocast
import segmentation_models_pytorch as smp
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler  # need pytorch>1.6
from losses import DiceLoss, FocalLoss, SoftCrossEntropyLoss, LovaszLoss
import albumentations as A
from dtd import *
from albumentations.pytorch import ToTensorV2
import torchvision
import argparse
import tempfile
from functools import partial
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath

# Command-line options for this evaluation script.
parser = argparse.ArgumentParser()
# Prefix that is string-concatenated with --lmdb_name to build the LMDB path.
parser.add_argument('--data_root', type=str, default='./')  # root to the dir of lmdb files
# Checkpoint file that eval_net_dtd loads via torch.load.
parser.add_argument('--pth', type=str, default='dtd.pth')
# Name of the LMDB directory appended to --data_root.
parser.add_argument('--lmdb_name', type=str, default='DocTamperV1-FCD')
# Forwarded to TamperDataset as `minq`.
parser.add_argument('--minq', type=int, default=75)
# Parsed at module import time, so importing this file consumes sys.argv.
args = parser.parse_args()


class TamperDataset(Dataset):
    """LMDB-backed tamper-detection dataset that re-compresses samples as JPEG.

    Each sample is read from the LMDB keys ``image-%09d`` / ``label-%09d``,
    converted to grayscale, saved as JPEG with the quality factors recorded
    for that index, and returned together with the luminance DCT coefficients
    of the saved file and the quantisation table of the last quality factor.
    """

    def __init__(self, roots, mode, minq=95, qtb=90, max_readers=64):
        """Open the LMDB and load the quantisation tables and quality records.

        Args:
            roots: Path of the LMDB directory. The text after the first
                ``'DocTamper'`` in this path selects the record pickle
                ``pks/DocTamper<rest>_<minq>.pk``.
            mode: Stored on the instance as ``self.mode``; not otherwise
                used in this class.
            minq: JPEG quality setting; stored and used to pick the record
                pickle.
            qtb: Accepted for interface compatibility; unused here.
            max_readers: Forwarded to ``lmdb.open``.
        """
        self.envs = lmdb.open(roots, max_readers=max_readers, readonly=True, lock=False, readahead=False, meminit=False)
        with self.envs.begin(write=False) as txn:
            self.nSamples = int(txn.get('num-samples'.encode('utf-8')))
        self.max_nums = self.nSamples
        self.minq = minq
        self.mode = mode
        # NOTE: pickle.load executes arbitrary code from the file; only use
        # trusted 'qt_table.pk' and 'pks/*.pk' files.
        with open('qt_table.pk', 'rb') as fpk:
            pks = pickle.load(fpk)
        # Quantisation tables as LongTensors, keyed the same way as the pickle.
        self.pks = {k: torch.LongTensor(v) for k, v in pks.items()}
        suffix = roots.split('DocTamper', 1)[-1]
        with open('pks/' + 'DocTamper' + suffix + '_%d.pk' % minq, 'rb') as f:
            self.record = pickle.load(f)
        self.hflip = torchvision.transforms.RandomHorizontalFlip(p=1.0)
        self.vflip = torchvision.transforms.RandomVerticalFlip(p=1.0)
        self.totsr = ToTensorV2()
        # NOTE(review): the second mean value is 0.455, while the commonly
        # quoted ImageNet value is 0.456. Left unchanged because the released
        # checkpoint may have been trained with it -- confirm before editing.
        self.toctsr = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                                      torchvision.transforms.Normalize(mean=(0.485, 0.455, 0.406),
                                                                                       std=(0.229, 0.224, 0.225))])

    def __len__(self):
        """Return the sample count stored under the LMDB key ``num-samples``."""
        return self.max_nums

    @staticmethod
    def _blocks_to_plane(dct_blocks):
        """Lay out per-block DCT coefficients as one 2-D coefficient plane.

        Args:
            dct_blocks: Array of shape ``(rows, cols, 64)`` where entry
                ``[i, j]`` holds the 64 coefficients of one 8x8 block.

        Returns:
            A new float64 array of shape ``(rows * 8, cols * 8)`` in which
            block ``[i, j]`` occupies ``[i*8:(i+1)*8, j*8:(j+1)*8]``. The
            size follows the input instead of being fixed to 512x512.
        """
        rows, cols = dct_blocks.shape[:2]
        # (rows, cols, 8, 8) -> (rows, 8, cols, 8) puts each block's rows next
        # to the matching rows of its horizontal neighbours before flattening.
        plane = dct_blocks.reshape(rows, cols, 8, 8).transpose(0, 2, 1, 3)
        return plane.reshape(rows * 8, cols * 8).astype(np.float64)

    def __getitem__(self, index):
        """Load sample ``index`` and return the model inputs as a dict.

        Returns:
            dict with keys
            ``'image'``: normalised RGB tensor of the re-compressed image,
            ``'label'``: binary mask as a long tensor (non-zero label pixels
            become 1),
            ``'rgb'``: absolute luminance DCT coefficients clipped to [0, 20],
            ``'q'``: quantisation table looked up for the last quality factor,
            ``'i'``: that last quality factor as an int.
        """
        with self.envs.begin(write=False) as txn:
            imgbuf = txn.get(('image-%09d' % index).encode('utf-8'))
            buf = six.BytesIO()
            buf.write(imgbuf)
            buf.seek(0)
            im = Image.open(buf)
            lblbuf = txn.get(('label-%09d' % index).encode('utf-8'))
            mask = (cv2.imdecode(np.frombuffer(lblbuf, dtype=np.uint8), 0) != 0).astype(np.uint8)
            # The record lists the JPEG quality factors for this sample; the
            # last entry is applied last, earlier entries (at most two are
            # used) are applied before it.
            record = self.record[index]
            choicei = len(record) - 1
            q = int(record[-1])
            use_qtb = self.pks[q]
            mask = self.totsr(image=mask.copy())['image']
            with tempfile.NamedTemporaryFile(delete=True) as tmp:
                im = im.convert("L")
                # NOTE(review): every save() below writes through the same
                # open handle and nothing truncates it in between, while
                # load() re-reads the file by name from its beginning.
                # Confirm that, after several saves, the file read back really
                # is the final compression -- TODO verify.
                if choicei > 1:
                    im.save(tmp, "JPEG", quality=int(record[-3]))
                    im = Image.open(tmp)
                if choicei > 0:
                    im.save(tmp, "JPEG", quality=int(record[-2]))
                    im = Image.open(tmp)
                im.save(tmp, "JPEG", quality=q)
                # Make sure nothing is left in Python-side buffers before the
                # file is opened a second time by name.
                tmp.flush()
                dct_y, _dct_cb, _dct_cr = load(tmp.name)
                dct = self._blocks_to_plane(dct_y)
                # Convert while the temp file still exists: the image opened
                # from it may not have been fully read yet.
                im = im.convert('RGB')
            return {
                'image': self.toctsr(im),
                'label': mask.long(),
                'rgb': np.clip(np.abs(dct), 0, 20),
                'q': use_qtb,
                'i': q
            }


# Build the evaluation dataset. The LMDB path is plain string concatenation of
# --data_root and --lmdb_name, so --data_root must already end with a path
# separator. The second positional argument (`mode`) is passed as False.
test_data = TamperDataset(args.data_root + args.lmdb_name, False, minq=args.minq)


def get_logger(filename, verbosity=1, name=None):
    """Return a logger that writes to both ``filename`` and the console.

    Args:
        filename: Log file path; it is opened in ``"w"`` mode.
        verbosity: 0 -> DEBUG, 1 -> INFO, 2 -> WARNING. Any other value
            raises ``KeyError``.
        name: Logger name passed to ``logging.getLogger``.

    Returns:
        The configured ``logging.Logger``. A new file handler and a new
        stream handler are attached on every call.
    """
    levels = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    fmt = logging.Formatter("[%(asctime)s][%(filename)s][%(levelname)s] %(message)s")

    log = logging.getLogger(name)
    log.setLevel(levels[verbosity])

    # File handler first, console handler second; both share one formatter.
    for handler in (logging.FileHandler(filename, "w"), logging.StreamHandler()):
        handler.setFormatter(fmt)
        log.addHandler(handler)
    return log


class AverageMeter(object):
    """Keep the latest value plus a running weighted sum, count and mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set ``val``, ``avg``, ``sum`` and ``count`` back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean.

        Args:
            val: The new observation; stored as ``self.val``.
            n: Weight of the observation (how many samples it stands for).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Mean of everything seen since the last reset.
        self.avg = self.sum / self.count


def second2time(second):
    """Format a duration given in seconds as a short string.

    Args:
        second: Duration in seconds (int or float).

    Returns:
        ``"S"`` (rounded to 4 decimals) for durations under a minute,
        ``"M:S"`` (seconds rounded to 1 decimal) under an hour, and
        ``"H:M:S"`` (all truncated to ints) for anything longer. Durations
        of 60 hours or more used to fall through and return ``None``; they
        are now formatted as ``"H:M:S"`` too.
    """
    if second < 60:
        return str('{}'.format(round(second, 4)))
    if second < 60 * 60:
        m, s = divmod(second, 60)
        return str('{}:{}'.format(int(m), round(s, 1)))
    # No upper bound on the hour branch, so long runs never yield None.
    h, remainder = divmod(second, 60 * 60)
    m, s = divmod(remainder, 60)
    return str('{}:{}:{}'.format(int(h), int(m), int(s)))


def inial_logger(file):
    """Attach a file handler and a console handler to the logger named 'log'.

    Args:
        file: Path of the log file handed to ``logging.FileHandler``.

    Returns:
        The ``'log'`` logger, set to DEBUG. The file handler records INFO and
        above, the console handler DEBUG and above; both print the bare
        message. Handlers are added on every call.
    """
    log = logging.getLogger('log')
    log.setLevel(level=logging.DEBUG)
    plain = logging.Formatter('%(message)s')

    to_file = logging.FileHandler(file)
    to_file.setLevel(level=logging.INFO)
    to_file.setFormatter(plain)

    to_console = logging.StreamHandler()
    to_console.setLevel(logging.DEBUG)
    to_console.setFormatter(plain)

    for handler in (to_file, to_console):
        log.addHandler(handler)
    return log


class IOUMetric:
    """Accumulate a confusion matrix and derive accuracy / IoU metrics from it."""

    def __init__(self, num_classes=10):
        self.num_classes = num_classes
        # Confusion matrix: row index is the true label, column the prediction.
        self.hist = np.zeros((num_classes, num_classes))

    def _fast_hist(self, label_pred, label_true):
        """Return the confusion matrix of one flattened prediction/label pair.

        Positions whose true label lies outside ``[0, num_classes)`` are
        ignored.
        """
        n = self.num_classes
        valid = (label_true >= 0) & (label_true < n)
        # Encode each (true, pred) pair as a single index true * n + pred.
        flat_index = n * label_true[valid].astype(int) + label_pred[valid]
        counts = np.bincount(flat_index, minlength=n ** 2)
        return counts.reshape(n, n)

    def add_batch(self, predictions, gts):
        """Add every (prediction, ground-truth) pair of a batch to the matrix."""
        for pred, gt in zip(predictions, gts):
            self.hist += self._fast_hist(pred.flatten(), gt.flatten())

    def evaluate(self):
        """Compute the metrics from everything accumulated so far.

        Returns:
            Tuple ``(acc, acc_cls, iu, mean_iu, fwavacc)``: overall accuracy,
            NaN-ignoring mean of per-class accuracy, per-class IoU array,
            NaN-ignoring mean IoU, and frequency-weighted IoU.
        """
        hist = self.hist
        total = hist.sum()
        correct = np.diag(hist)
        per_true = hist.sum(axis=1)
        per_pred = hist.sum(axis=0)

        acc = correct.sum() / total
        acc_cls = np.nanmean(correct / per_true)
        iu = correct / (per_true + per_pred - correct)
        mean_iu = np.nanmean(iu)
        freq = per_true / total
        present = freq > 0
        fwavacc = (freq[present] * iu[present]).sum()
        return acc, acc_cls, iu, mean_iu, fwavacc


# Instantiate the network on the GPU and wrap it in DataParallel.
# NOTE(review): seg_dtd is not defined in this file (presumably it comes from
# `from dtd import *`); the meaning of its arguments ('' and 2) must be checked
# there -- the 2 presumably is the number of output classes.
model = seg_dtd('', 2).cuda()
model = torch.nn.DataParallel(model)


def batch_precision_recall(pred_mask, target_mask, eps=1e-8):
    """Return the batch-mean pixel precision and recall of binary masks.

    Args:
        pred_mask: Tensor of 0/1 predictions, indexed (batch, H, W).
        target_mask: Tensor of 0/1 ground truth with the same layout.
        eps: Added to both denominators so that a sample with no predicted
            or no tampered pixels yields 0 instead of NaN.

    Returns:
        ``(precision, recall)`` as Python floats, each averaged over the
        batch dimension.
    """
    matched = (pred_mask * target_mask).sum((1, 2))
    precision = matched / (pred_mask.sum((1, 2)) + eps)
    recall = matched / (target_mask.sum((1, 2)) + eps)
    return precision.mean().item(), recall.mean().item()


# Evaluate a model on the test set and print IoU / precision / recall / F1.
def eval_net_dtd(model, test_data, plot=False, device='cuda'):
    """Load the checkpoint given by ``args.pth`` and evaluate ``model``.

    Args:
        model: Network called as ``model(image, dct_coef, qtable)`` and
            returning per-class scores along dimension 1.
        test_data: Dataset yielding dicts with the keys ``'image'``,
            ``'label'``, ``'rgb'``, ``'q'`` and ``'i'``.
        plot: Unused; kept for interface compatibility.
        device: Device the batch tensors are moved to.

    Prints one ``[val]`` line with per-class IoU, precision, recall and F1.
    Precision and recall are averaged per batch and then across batches.
    """
    loader = DataLoader(dataset=test_data, batch_size=16, num_workers=4, shuffle=False)
    # Weights are read from the global CLI option, not from a parameter.
    ckpt = torch.load(args.pth, map_location='cpu')
    model.load_state_dict(ckpt['state_dict'])
    model.eval()

    iou = IOUMetric(2)
    precisons = []
    recalls = []
    # Inference only: no gradients needed.
    with torch.no_grad():
        for batch_samples in tqdm(loader):
            data = batch_samples['image'].to(device)
            target = batch_samples['label'].to(device)
            dct_coef = batch_samples['rgb'].to(device)
            # The quantisation tables get an extra dimension at index 1
            # before being passed to the model.
            qs = batch_samples['q'].unsqueeze(1).to(device)

            pred = model(data, dct_coef, qs)
            # Class index with the highest score per pixel.
            predt = pred.argmax(1)
            # Drop the dimension at index 1 of the label so it lines up with
            # predt.
            targt = target.squeeze(1)

            precision, recall = batch_precision_recall(predt, targt)
            precisons.append(precision)
            recalls.append(recall)
            iou.add_batch(predt.cpu().numpy(), target.cpu().numpy())

    acc, acc_cls, iu, mean_iu, fwavacc = iou.evaluate()
    precisons = np.array(precisons).mean()
    recalls = np.array(recalls).mean()
    print('[val] iou:{} pre:{} rec:{} f1:{}'.format(iu, precisons, recalls,
                                                    (2 * precisons * recalls / (precisons + recalls + 1e-8))))



# Script entry point: runs at import time, loads the checkpoint named by --pth
# and prints the evaluation metrics for test_data.
eval_net_dtd(model, test_data)
